{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo - Adversarial Training (MNIST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "import torchvision.utils\n",
    "from torchvision import models\n",
    "import torchvision.datasets as dsets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import torchattacks\n",
    "from torchattacks import PGD, FGSM\n",
    "\n",
    "from models import CNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both MNIST splits use the same root directory, the same tensor conversion\n",
    "# and the same download flag; only the `train` flag differs between them.\n",
    "mnist_kwargs = dict(root='./data/',\n",
    "                    transform=transforms.ToTensor(),\n",
    "                    download=True)\n",
    "\n",
    "mnist_train = dsets.MNIST(train=True, **mnist_kwargs)\n",
    "mnist_test = dsets.MNIST(train=False, **mnist_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "\n",
    "# Reshuffle the training set at every epoch so the optimizer does not see\n",
    "# the same batches in the same fixed order each time.\n",
    "train_loader  = torch.utils.data.DataLoader(dataset=mnist_train,\n",
    "                                           batch_size=batch_size,\n",
    "                                           shuffle=True)\n",
    "\n",
    "# The test loader is only iterated to count correct predictions, so its\n",
    "# order is left unshuffled.\n",
    "test_loader = torch.utils.data.DataLoader(dataset=mnist_test,\n",
    "                                         batch_size=batch_size,\n",
    "                                         shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Define Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the model on the GPU when one is available; otherwise fall back to the\n",
    "# CPU so that building the model does not raise on machines without CUDA.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = CNN().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over all of the model's parameters, learning rate 1e-3.\n",
    "optimizer = optim.Adam(params=model.parameters(), lr=1e-3)\n",
    "\n",
    "# Cross-entropy criterion; called as loss(predictions, labels) during training.\n",
    "loss = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PGD attack bound to `model`; the training loop calls it on every batch.\n",
    "# Argument meanings below follow the torchattacks PGD documentation.\n",
    "atk = PGD(\n",
    "    model,\n",
    "    eps=0.3,    # maximum perturbation\n",
    "    alpha=0.1,  # step size\n",
    "    steps=7,    # number of attack steps\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of passes the training loop makes over `train_loader`.\n",
    "num_epochs = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/5], lter [100/468], Loss: 2.2063\n",
      "Epoch [1/5], lter [200/468], Loss: 1.8595\n",
      "Epoch [1/5], lter [300/468], Loss: 1.8766\n",
      "Epoch [1/5], lter [400/468], Loss: 1.6352\n",
      "Epoch [2/5], lter [100/468], Loss: 1.4859\n",
      "Epoch [2/5], lter [200/468], Loss: 1.1941\n",
      "Epoch [2/5], lter [300/468], Loss: 1.4796\n",
      "Epoch [2/5], lter [400/468], Loss: 1.1723\n",
      "Epoch [3/5], lter [100/468], Loss: 1.0888\n",
      "Epoch [3/5], lter [200/468], Loss: 0.8527\n",
      "Epoch [3/5], lter [300/468], Loss: 1.0899\n",
      "Epoch [3/5], lter [400/468], Loss: 0.8288\n",
      "Epoch [4/5], lter [100/468], Loss: 0.9201\n",
      "Epoch [4/5], lter [200/468], Loss: 0.6490\n",
      "Epoch [4/5], lter [300/468], Loss: 0.8750\n",
      "Epoch [4/5], lter [400/468], Loss: 0.6238\n",
      "Epoch [5/5], lter [100/468], Loss: 0.7873\n",
      "Epoch [5/5], lter [200/468], Loss: 0.5471\n",
      "Epoch [5/5], lter [300/468], Loss: 0.8011\n",
      "Epoch [5/5], lter [400/468], Loss: 0.5021\n"
     ]
    }
   ],
   "source": [
    "# Take the batch count from the loader itself: it keeps the final, smaller\n",
    "# batch, so floor division under-reports the count whenever the dataset size\n",
    "# is not a multiple of batch_size.\n",
    "total_batch = len(train_loader)\n",
    "\n",
    "# Keep every tensor on whichever device the model's parameters live on.\n",
    "device = next(model.parameters()).device\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "\n",
    "    # The evaluation cells put the model in eval mode; switch back so that\n",
    "    # re-running this cell always optimizes in training mode.\n",
    "    model.train()\n",
    "\n",
    "    for i, (batch_images, batch_labels) in enumerate(train_loader):\n",
    "        # Train on adversarial examples crafted from the clean batch.\n",
    "        X = atk(batch_images, batch_labels).to(device)\n",
    "        Y = batch_labels.to(device)\n",
    "\n",
    "        pre = model(X)\n",
    "        cost = loss(pre, Y)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        cost.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Report the current batch loss every 100 iterations.\n",
    "        if (i+1) % 100 == 0:\n",
    "            print('Epoch [%d/%d], Iter [%d/%d], Loss: %.4f'\n",
    "                 %(epoch+1, num_epochs, i+1, total_batch, cost.item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Test Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 Standard Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Standard accuracy: 96.14 %\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "\n",
    "# Evaluate on whichever device the model's parameters live on.\n",
    "device = next(model.parameters()).device\n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "# Clean-data evaluation needs no gradients, so skip building the autograd graph.\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_loader:\n",
    "\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        outputs = model(images)\n",
    "        predicted = outputs.argmax(dim=1)\n",
    "\n",
    "        total += labels.size(0)\n",
    "        # .item() keeps the running count a plain Python int.\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print('Standard accuracy: %.2f %%' % (100 * float(correct) / total))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 Robust Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Robust accuracy: 85.22 %\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "\n",
    "# Evaluate on whichever device the model's parameters live on.\n",
    "device = next(model.parameters()).device\n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "# Use a separate name for the evaluation attack: rebinding `atk` here would\n",
    "# make a re-run of the training cell silently train against FGSM, not PGD.\n",
    "fgsm_atk = FGSM(model, eps=0.3)\n",
    "\n",
    "for images, labels in test_loader:\n",
    "\n",
    "    # The attack is called outside no_grad so autograd stays enabled for it.\n",
    "    adv_images = fgsm_atk(images, labels).to(device)\n",
    "    labels = labels.to(device)\n",
    "\n",
    "    # Scoring the adversarial images needs no gradients.\n",
    "    with torch.no_grad():\n",
    "        outputs = model(adv_images)\n",
    "        predicted = outputs.argmax(dim=1)\n",
    "\n",
    "    total += labels.size(0)\n",
    "    # .item() keeps the running count a plain Python int.\n",
    "    correct += (predicted == labels).sum().item()\n",
    "\n",
    "print('Robust accuracy: %.2f %%' % (100 * float(correct) / total))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
